"""
研究transfer对
x轴移动和旋转的影响。
"""
import argparse
from collections import defaultdict

from torch import optim
from torch.utils.data import DataLoader
import wandb
from tqdm import trange

from disvae.models import encoders, decoders
from disvae.models.anneal import *
from disvae.models.losses import BtcvaeLoss, get_loss_s, _kl_normal_loss, _reconstruction_loss
from exps.models import *
from exps.visualize import *
from exps.patterns import data_generator
import numpy as np
from utils.helpers import set_seed

# Default hyperparameters; each key is also exposed as a CLI flag of the same
# name below, so any value can be overridden from the command line.
hyperparameter_defaults = dict(
    batch_size=256,
    learning_rate=5e-4,
    epochs=441,
    beta=5,  # weight of the KL term in the beta-VAE objective
    dimension=3,  # latent dimensionality
    img_id=7,  # which pattern file to load (patterns/<img_id>.pat)
    random_seed=345,
    file=__file__.split('/')[-1]  # script name, recorded for bookkeeping
)

# Build one --<name> flag per default; the flag's type is inferred from the
# default's type (note: this idiom would misbehave for bool defaults, but
# none are used here).
parser = argparse.ArgumentParser()
for key, value in hyperparameter_defaults.items():
    parser.add_argument(f'--{key}', default=value, type=type(value))
args = parser.parse_args()

# Register the run with wandb and seed all RNGs for reproducibility.
wandb.init(project="exps", config=args)
config = args
set_seed(config.random_seed)


def train(dl, model, epochs):
    """Train the encoder (the decoder is frozen) for `epochs` passes over `dl`.

    Optimizes a reconstruction loss plus a beta-weighted KL term that is
    divided by (1 + C), where the capacity C ramps linearly from 0 to 20
    over the entire run. Per-dimension KL, reconstructions and latent
    diagnostics are logged to wandb.

    Parameters
    ----------
    dl : DataLoader yielding (img, label) batches.
    model : nn.Sequential with named submodules 'encoder' and 'decoder'.
    epochs : int, number of epochs to train.

    Returns
    -------
    storer : dict of metrics accumulated during the final epoch.

    Notes
    -----
    Reads the module-level globals `config` (beta, learning_rate) and `dim`
    (latent dimensionality).
    """
    opt = optim.AdamW(model.parameters(), config.learning_rate)
    # C grows linearly from 0 to 20 across all optimizer steps of the run.
    inc = 20 / len(dl) / epochs
    C = 0
    for e in trange(epochs):
        storer = defaultdict(list)

        for img, label in dl:
            img = img.view(-1, 1, 64, 64).cuda()
            # Re-centre and scale labels: the two translation components are
            # divided by 11.5470, the rotation component is left unscaled.
            # NOTE(review): the constants 19 / 11.5470 come from the data
            # generator's grid — confirm against data_generator.gen_rotation.
            label = ((label - 19) / torch.Tensor([11.5470, 11.5470, 1])).cuda()

            # Use the model's own submodules rather than the module-level
            # `encoder`/`decoder` globals (same objects here, but this makes
            # the function honor its `model` argument).
            mean, logvar = model.encoder(img)
            latent_sample = reparameterize(mean, logvar)

            # Per-dimension KL(q(z|x) || N(0, I)), averaged over the batch.
            latent_kl = 0.5 * (-1 - logvar + mean.pow(2) + logvar.exp()).mean(dim=0)
            # Effective KL weight shrinks as the capacity C increases.
            kl = config.beta * (latent_kl.sum()) / (1 + C)
            C += inc

            recon = model.decoder(latent_sample)

            recon_loss = _reconstruction_loss(img, recon)
            loss = recon_loss + kl

            opt.zero_grad()
            loss.backward()
            opt.step()

            for i in range(dim):
                storer[f'kl_{i}'].append(latent_kl[i].item())

        # Collapse the per-batch metric lists into epoch means.
        for k, v in storer.items():
            if isinstance(v, list):
                storer[k] = np.mean(v)

        # Every 80 epochs, log richer diagnostics computed from the last batch.
        if (e + 1) % 80 == 0:
            storer['C'] = C
            for i in range(dim):
                storer[f'z_{i}'] = wandb.Histogram(latent_sample[:, i].data.cpu().numpy())
            storer['recon'] = wandb.Image(recon[0, 0])
            storer['img'] = wandb.Image(img[0, 0])

            # Rank latent dimensions by their epoch-mean KL (descending).
            kl_loss = torch.Tensor([v for k, v in storer.items() if 'kl_' in k])
            _, index = kl_loss.sort(descending=True)

            # Scatter the three highest-KL latents against the ground truth.
            preds, targets = get_preds(model, dl)
            preds, targets = preds[:1000, index[:3]], targets[:1000]
            fig = gt_vs_latent(preds.cpu(), targets.cpu())
            storer['correlated'] = wandb.Image(fig)

        wandb.log(storer, step=e, sync=True)
    return storer


# ---- Data generation -------------------------------------------------------
# Load the base 64x64 pattern and generate rotated/translated variants of it.
img_id = config.img_id
img = torch.load(f'patterns/{img_id}.pat')

imgs, labels = data_generator.gen_rotation(img, img.shape)

epochs = config.epochs
dim = config.dimension  # latent dimensionality; read as a global by train()

# Keep only the slice at index 19 along the third axis of the generated grid.
# NOTE(review): 19 appears to be the grid centre (train() re-centres labels
# with `label - 19`) — confirm against data_generator.gen_rotation.
dataset = imgs[:, :, 19].reshape(-1, 1, 64, 64)
nlabels = labels[:, :, 19].reshape(-1, 3)

# ---- Model assembly --------------------------------------------------------
# Fresh encoder; the decoder is restored from a pre-trained checkpoint and
# frozen, so only the encoder's weights are updated (the transfer setting).
encoder = encoders.EncoderBurgess((1, 64, 64), dim)
# encoder.load_state_dict(torch.load('opt_encoder.pth'))
decoder = decoders.DecoderBurgess((1, 64, 64), dim)
decoder.load_state_dict(torch.load('opt_decoder.pth'))
decoder.requires_grad_(False)
model = nn.Sequential()
model.add_module('encoder', encoder)
model.add_module('decoder', decoder)

model.train()
model.cuda()

dl = DataLoader(list(zip(dataset, nlabels)), pin_memory=True,
                num_workers=0, batch_size=config.batch_size, shuffle=True)

# ---- Train, then evaluate --------------------------------------------------
storer = train(dl, model, epochs)

# Rank latent dimensions by their final-epoch mean KL, most informative first.
kl_loss = torch.Tensor([v for k, v in storer.items() if 'kl_' in k])
_, index = kl_loss.sort(descending=True)
print(index)
model.eval()
model.cpu()

# 3-D point cloud of the three highest-KL latents over the whole dataset.
preds, targets = get_preds(model, dl)
points = preds[:, index[:3]].cpu()
storer['point_cloud'] = wandb.Object3D(points.numpy())

# Latent-traversal figure across the (KL-ordered) latent dimensions.
fig = plt_sample_traversal(None, model, 7, index, r=2)
storer['traversal'] = wandb.Image(fig)

wandb.log(storer)
plt.show()

# torch.save(model.state_dict(),'optimal.pth')
